source('codes/00_import_libraries.R')


#------------------------------------------------------------------------
# Specifications
#------------------------------------------------------------------------

# Regressors: the 'War' column and the column named by PLS
PLS = 'PLS_All_recursive'
x_list = c('War', PLS)

# Prefix of the dependent-variable columns; each horizon in `h` is appended
# to it to form a column name
yvar = 'mkt'

# Horizons and the matching row labels for the table
h = c(1, 3, 6, 12, 24, 36)
row_names = paste('$h$ =', h)

# First and last year of each sample period (one row per period)
year_grid = data.frame(
  beg = c(1871, 1871, 1950, 2000),
  end = c(2019, 1949, 2019, 2019)
)

# Period labels of the form 'beg-end', in the same order as year_grid
period_list = with(year_grid, paste(beg, end, sep = '-'))



#------------------------------------------------------------------------
# Data
#------------------------------------------------------------------------

# Name of the element to pull out of the list stored in combined_data.rds
dict_name = 'base'

# Load the data and prepare the regression sample:
#   - add the calendar year of `ym`
#   - multiply the return column of every horizon in `h` by 12
#     (presumably annualizing monthly returns -- confirm against the data)
#   - drop rows where the horizon-1 return is missing; the column name is
#     built from `yvar` so it stays in sync with the specification above
#   - keep the date, the year, every column whose name contains `yvar`, and
#     whichever regressors in `x_list` are present
dt = readRDS('input/combined_data.rds')[[dict_name]] %>%
  mutate(
    year = year(ym),
    across(all_of(paste0(yvar, h)), function(x) x * 12)
  ) %>%
  drop_na(all_of(paste0(yvar, 1))) %>%
  select(ym, year, contains(yvar), any_of(x_list))



#------------------------------------------------------------------------
# Functions
#------------------------------------------------------------------------

ols_res = function(xvar, yvar, h, yr_range, data = dt) {
  # Regress the return column of every horizon on a single regressor and
  # return the formatted estimate, t-statistic and adjusted R2 per horizon.
  #
  # Args:
  #   xvar:     name of the regressor; also the only coefficient kept by
  #             huxreg().
  #   yvar:     prefix of the dependent-variable columns; each element of `h`
  #             is appended to it to form a column name.
  #   h:        vector of horizons. Each value is also used as the Newey-West
  #             lag of the corresponding regression.
  #   yr_range: years to keep; rows whose `year` is not in it are dropped.
  #   data:     data frame with a `year` column and the regression columns.
  #             Defaults to the script-level `dt`, which the function
  #             previously read implicitly.
  #
  # Returns:
  #   A huxtable with one row per horizon: the huxreg() estimate and
  #   statistic cells for `xvar`, followed by the adjusted R2 in percent.

  # restrict the sample once and reuse it for every horizon
  sub_df = data %>% filter(year %in% yr_range)

  # fit one model per horizon; fn_fit_ols is defined outside this file
  # (presumably an OLS fit of the given column on xvar -- confirm)
  mod_list = lapply(paste0(yvar, h), fn_fit_ols, xvar = xvar, df = sub_df)

  # coefficient tests with Newey-West standard errors, pairing each model
  # with its own horizon as the lag
  mod_robust = Map(
    function(mod, n_lag) {
      coeftest(mod, vcov. = NeweyWest(mod, lag = n_lag, prewhite = FALSE))
    },
    mod_list,
    h
  )

  # adjusted R2 of each model, in percent; vapply enforces exactly one
  # numeric value per model
  is_r2 = vapply(
    mod_list,
    function(mod) glance(mod)$adj.r.squared * 100,
    numeric(1)
  )

  # create a huxtable from the regressions: only the xvar coefficient, with
  # the test statistic in parentheses and no summary statistics or note
  reg_table = huxreg(
    mod_robust,
    coefs         = xvar,
    error_format  = "({statistic})",
    stars         = c(`***` = 0.01, `**` = 0.05, `*` = 0.1),
    statistics    = c(0),
    borders       = 0,
    outer_borders = 0,
    note          = NULL
  )

  # transpose so that horizons run down the rows, drop the first row and
  # column of the transposed table (presumably the term labels and model
  # names produced by huxreg), then append the R2 column
  t(reg_table)[-1, -1] %>%
    cbind(is_r2)
}



#------------------------------------------------------------------------
# Results
#------------------------------------------------------------------------

# Header row: an empty cell above the horizon labels, then estimate,
# t-statistic and R2 for each of the two regressors in x_list
column_names = c('','War','\\textit{t}-stat','$R^2$', 'PLS','\\textit{t}-stat','$R^2$')

# Build one panel per sample period, then stack the panels into one table.
# seq_len() keeps the loop empty (instead of iterating over 1, 0) should
# year_grid ever have no rows.
output = lapply(seq_len(nrow(year_grid)), function(i) {

    # apply ols_res to each regressor for the years of period i
    lapply(
        x_list,
        ols_res,
        yvar       = yvar,
        h          = h,
        yr_range   = year_grid$beg[i]:year_grid$end[i]
      ) %>%

      # place the per-regressor results side by side and label the horizons
      do.call(cbind, .) %>%
      insert_column(row_names) %>%

      # add a title row holding the period label, merged across all columns,
      # with 0 decimals for numbers, centered, and with top/bottom borders
      insert_row(period_list[i], fill = '') %>%
      merge_cells(1, everywhere) %>%
      set_number_format(1, everywhere, 0) %>%
      set_align(1, everywhere, 'center') %>%
      set_tb_borders(1, everywhere, 0.8) %>%

      # close the panel with a border under its last row
      set_bottom_border(final(1), everywhere, 0.8)
    }
  ) %>%

  # stack the period panels and use 2 decimals everywhere except column 1
  do.call(rbind, .) %>%
  set_number_format(everywhere, -1, 2) %>%

  # add the header row, border it, align the body cells on the decimal point
  # and center the header
  insert_row(column_names) %>%
  set_tb_borders(1, everywhere, 0.8) %>%
  set_align(-1, -1, '.') %>%
  set_align(1, everywhere, 'center') %>%

  # remove cell padding and switch off escaping of the cell contents
  # (the header and row labels contain LaTeX markup)
  set_tb_padding(0) %>%
  set_lr_padding(0) %>%
  set_escape_contents(FALSE)



#------------------------------------------------------------------------
# Export
#------------------------------------------------------------------------

# Write the table to an Excel workbook. The path is a plain literal with
# nothing to interpolate, so it is passed directly rather than through glue().
quick_xlsx(output, file = 'output/tables/Table_4.xlsx')

# Optional LaTeX export of the same table (disabled):
# cat(to_latex(output, tabular_only = TRUE), file = 'output/tables/Table_4.tex')
